import tensorflow as tf
import numpy as np
from model import build_model

# Load the trained model and run single-image predictions
def load_model(path='cnn_model.h5'):
    """Load the trained Keras model from disk.

    Args:
        path: Location of the saved model file. Defaults to
            ``'cnn_model.h5'`` (resolved against the current working
            directory), which preserves the original hard-coded behavior.

    Returns:
        The model object returned by ``tf.keras.models.load_model``.
    """
    return tf.keras.models.load_model(path)

def predict_image(image, model=None):
    """Predict the class index for a single image.

    Args:
        image: One sample *without* a batch dimension. A leading batch
            axis of size 1 is added before the model is called.
        model: Optional already-loaded model exposing ``predict``. When
            omitted, the model is loaded from disk via ``load_model()`` on
            every call, which is the original behavior. Pass a model
            explicitly when predicting many images, so it is not reloaded
            each time.

    Returns:
        The index of the largest value in the model's output, computed by
        ``np.argmax`` over the flattened predictions, as a NumPy integer.
    """
    if model is None:
        model = load_model()
    # NOTE(review): the image is fed as-is (no dtype cast or scaling);
    # confirm this matches the preprocessing used when the model was trained.
    batch = np.expand_dims(image, axis=0)
    predictions = model.predict(batch)
    return np.argmax(predictions)

# Quick manual check: predict the class of one CIFAR-10 test image
if __name__ == "__main__":
    # Only the test images are used; the training split and all labels
    # are bound to `_` because they are discarded.
    (_, _), (test_images, _) = tf.keras.datasets.cifar10.load_data()
    test_image = test_images[0]
    prediction = predict_image(test_image)
    print(f"Predicted class: {prediction}")

